"""
OpenSeesVispy 所使用的函数以及数据
"""

import h5py
import numpy as np
import openseespy.opensees as ops

# OpenSees element class tags (as returned by ops.getEleClassTags), grouped by
# element family; used to determine the element type.
ELE_TAG_Truss = (12, 13, 14, 15, 16, 17, 18, 169)  # 169 CatenaryCable
ELE_TAG_Link = (19, 20, 21, 22, 23, 24, 25, 26, 260, 27,  # zeroLength
                86,  # 86 is twoNodeLink
                84, 85, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,
                105, 106, 107, 108, 109, 130, 131, 132, 147, 148, 149, 150, 151, 152, 153,
                158, 159, 160, 161, 165, 166)  # Bearing
ELE_TAG_Beam = (3, 4, 5, 5001, 6, 7, 8, 9, 10, 11, 28, 29, 30, 34, 35, 62, 621,
                63, 64, 640, 641, 642,
                65, 66, 67, 68, 69, 70, 71, 72, 73, 731, 74, 75, 751, 76, 77, 78,
                30766, 30765, 30767, 79, 128,)
ELE_TAG_Plane = (31, 32, 33, 40, 47, 50, 52, 53, 54, 55, 59, 61, 119, 120, 126, 134,
                 156, 157, 167, 168, 173, 174, 175, 203, 204, 207, 208, 209)
# NOTE(review): 71 and 72 are listed in both ELE_TAG_Beam and ELE_TAG_Joint — confirm which is intended.
ELE_TAG_Joint = (71, 72, 81, 8181, 82, 83,)
ELE_TAG_Tetrahedron = (179,)
ELE_TAG_Brick = (36, 37, 38, 39, 41, 42, 43, 44, 45, 46, 48, 49, 51, 56, 57, 58,
                 121, 122, 127)


def _check_file(file_name, file_type):
    if not file_name.endswith(file_type):
        raise ValueError(f"file must be endswith {file_type}!")


class model_base:
    """
    Base class for extracting OpenSees model data; meant to be subclassed.

    Collected data:

    * ``model_info``: node coordinates, element midpoints, plot bounds, counts,
      node/element tags and per-node model dimension (see ``get_model_data``);
    * ``cells``: element connectivity lists in the padded layout
      ``[n, idx_1, ..., idx_n]`` where ``idx_k`` are row indices into
      ``model_info['coord_no_deform']`` (not node tags);
    * ``eigen``: frequencies and mode shapes (see ``get_eigen_data``);
    * ``*_steps``: per-analysis-step snapshots (see ``get_step_data``).
    """

    def __init__(self):
        # Model data, filled by get_model_data().
        self.model_info_names = ['coord_no_deform', 'coord_ele_midpoints',
                                 'bound', 'max_bound', 'num_node', 'num_ele', 'NodeTags', 'EleTags',
                                 'model_dims']
        self.model_info = dict.fromkeys(self.model_info_names)

        # Element connectivity data, filled by get_model_data().
        self.cells_names = ['truss', 'link', 'beam', 'other_line', 'all_lines',
                            'plane', 'tetrahedron', 'brick', 'all_faces']
        self.cells = dict.fromkeys(self.cells_names)

        # Set to True by get_model_data(). get_step_data() reads it, so it must
        # exist before the first extraction.
        self.model_data_finished = False

        self.reset_eigen_state()
        self.reset_steps_state()

    @staticmethod
    def _to_3d(values, dim):
        """
        Return the first ``dim`` entries of ``values`` zero-padded to length 3.

        Turns nodal coordinates / responses of 1D, 2D and 3D nodes into
        3-component vectors; entries beyond ``dim`` (e.g. rotations) are dropped.

        :param values: sequence of numbers (e.g. from ops.nodeCoord / ops.nodeDisp)
        :param dim: spatial dimension of the node (1, 2 or 3)
        :return: list of 3 floats
        """
        vec = [float(v) for v in list(values)[:dim]]
        vec.extend([0.0] * (3 - len(vec)))
        return vec

    @staticmethod
    def _brick_faces(corner_idx, node_coords):
        """
        Build the six quadrilateral faces ``[4, a, b, c, d]`` of a hexahedron.

        The eight corners are split by their last (z) coordinate into a lower and
        an upper group of four; each group is ordered with ``counter_clockwise``
        and the faces are assembled from the reordered corners.

        :param corner_idx: 8 row indices into ``node_coords``
        :param node_coords: array of node coordinates, one row per node
        :return: list of six face cells
        """
        corner_idx = np.asarray(corner_idx)
        corner_idx = corner_idx[np.argsort(node_coords[corner_idx][:, -1])]
        corner_pts = [tuple(p) for p in node_coords[corner_idx]]
        corner_tags = list(corner_idx)
        i1, i2, i3, i4 = counter_clockwise(corner_pts[:4], corner_tags[:4])
        i5, i6, i7, i8 = counter_clockwise(corner_pts[4:], corner_tags[4:])
        return [
            [4, i1, i4, i3, i2],
            [4, i5, i6, i7, i8],
            [4, i1, i5, i8, i4],
            [4, i2, i3, i7, i6],
            [4, i1, i2, i6, i5],
            [4, i3, i4, i8, i7],
        ]

    def reset_eigen_state(self):
        """
        Reset the containers that hold eigen-analysis results.
        """
        # Frequencies, mode shapes and a snapshot of the model info and cells.
        self.eigen_names = ['f', 'eigenvector'] + self.model_info_names + self.cells_names
        self.eigen = {name: [] for name in self.eigen_names}

    def reset_steps_state(self):
        """
        Reset the state of results extract
        """
        # Model data recorded per analysis step.
        self.model_info_steps = {name: [] for name in self.model_info_names}

        # Element connectivity recorded per analysis step.
        self.cells_steps = {name: [] for name in self.cells_names}

        # Nodal responses recorded per analysis step.
        self.node_resp_names = ('disp', 'vel', 'accel')
        self.node_resp_steps = {name: [] for name in self.node_resp_names}

        # Placeholders (not implemented yet): truss element, beam element and
        # fiber section step responses.

        self.model_update = False

        self.step_track = 0

    def get_model_data(self, output_file=None):
        """
        Extract the finite element model data from the current OpenSees domain.

        Fills ``self.model_info`` and ``self.cells`` and sets
        ``self.model_data_finished``. Elements are classified by node count:
        2 -> line (truss/link/beam/other by class tag), 3 -> triangle,
        4 or 9 -> tetrahedron (by class tag) or quadrilateral (first 4 nodes),
        8 or 20 -> hexahedron (first 8 nodes). Other node counts produce no
        cells but still get a midpoint, so midpoints stay aligned with EleTags.

        :param output_file: optional hdf5 file to write the data to; must end with '.hdf5'
        :raises ValueError: if ``output_file`` does not end with '.hdf5'
        """
        if output_file is not None:
            _check_file(output_file, '.hdf5')

        NodeTags = ops.getNodeTags()
        NodeTags.sort()
        num_node = len(NodeTags)
        EleTags = ops.getEleTags()
        EleTags.sort()
        num_ele = len(EleTags)

        # Node coordinates padded to 3 components; Node_Index maps node tag -> row.
        Node_Coords = np.zeros((num_node, 3))
        Node_Index = dict()
        model_dims = []
        for i, tag in enumerate(NodeTags):
            coord = ops.nodeCoord(tag)
            model_dims.append(len(coord))
            Node_Coords[i] = self._to_3d(coord, len(coord))
            Node_Index[tag] = i
        points = Node_Coords

        # Connectivity lists: [number of nodes, row index 1, ..., row index n].
        truss_cells = []
        link_cells = []
        beam_cells = []
        other_line_cells = []
        all_lines_cells = []
        plane_cells = []
        tetrahedron_cells = []
        brick_cells = []
        all_faces_cells = []
        ele_midpoints = []  # one entry per element, aligned with EleTags
        for ele in EleTags:
            ele_nodes = ops.eleNodes(ele)
            n_nodes = len(ele_nodes)
            if n_nodes == 2:
                idxI, idxJ = (Node_Index[n] for n in ele_nodes)
                all_lines_cells.append([2, idxI, idxJ])
                class_tag = ops.getEleClassTags(ele)[0]
                if class_tag in ELE_TAG_Truss:
                    target = truss_cells
                elif class_tag in ELE_TAG_Link:
                    target = link_cells
                elif class_tag in ELE_TAG_Beam:
                    target = beam_cells
                else:
                    target = other_line_cells
                target.append([2, idxI, idxJ])
                corners = [idxI, idxJ]
            elif n_nodes == 3:
                idxI, idxJ, idxK = (Node_Index[n] for n in ele_nodes)
                all_faces_cells.append([3, idxI, idxJ, idxK])
                plane_cells.append([3, idxI, idxJ, idxK])
                corners = [idxI, idxJ, idxK]
            elif n_nodes in (4, 9):
                # 9-node quads are drawn through their first 4 (corner) nodes.
                idxI, idxJ, idxK, idxL = (Node_Index[n] for n in ele_nodes[:4])
                if ops.getEleClassTags(ele)[0] in ELE_TAG_Tetrahedron:
                    for face in ((idxI, idxJ, idxK), (idxI, idxJ, idxL),
                                 (idxI, idxK, idxL), (idxJ, idxK, idxL)):
                        tetrahedron_cells.append([3, *face])
                        # Every tetrahedron face belongs to the global face set.
                        all_faces_cells.append([3, *face])
                else:
                    plane_cells.append([4, idxI, idxJ, idxK, idxL])
                    all_faces_cells.append([4, idxI, idxJ, idxK, idxL])
                corners = [idxI, idxJ, idxK, idxL]
            elif n_nodes in (8, 20):
                # 20-node bricks are drawn through their first 8 (corner) nodes.
                corners = [Node_Index[n] for n in ele_nodes[:8]]
                faces = self._brick_faces(corners, Node_Coords)
                brick_cells.extend(faces)
                all_faces_cells.extend(list(face) for face in faces)
            else:
                # Not drawn; still record a midpoint to keep alignment with EleTags.
                corners = [Node_Index[n] for n in ele_nodes]
            if corners:
                ele_midpoints.append(Node_Coords[corners].mean(axis=0))
            else:
                ele_midpoints.append(np.full(3, np.nan))
        ele_midpoints = np.array(ele_midpoints)

        # Axis box: node bounding box enlarged by 2/15 of its size on every side.
        minNode = np.min(points, axis=0)
        maxNode = np.max(points, axis=0)
        space = (maxNode - minNode) / 15
        minNode = minNode - 2 * space
        maxNode = maxNode + 2 * space
        bounds = [minNode[0], maxNode[0], minNode[1], maxNode[1], minNode[2], maxNode[2]]
        max_bound = np.max(maxNode - minNode)

        self.model_info['coord_no_deform'] = points
        self.model_info['coord_ele_midpoints'] = ele_midpoints
        self.model_info['bound'] = bounds
        self.model_info['max_bound'] = max_bound
        self.model_info['num_node'] = num_node
        self.model_info['num_ele'] = num_ele
        self.model_info['NodeTags'] = NodeTags
        self.model_info['EleTags'] = EleTags
        self.model_info['model_dims'] = model_dims

        self.cells['all_lines'] = all_lines_cells
        self.cells['all_faces'] = all_faces_cells
        self.cells['plane'] = plane_cells
        self.cells['tetrahedron'] = tetrahedron_cells
        self.cells['brick'] = brick_cells
        self.cells['truss'] = truss_cells
        self.cells['link'] = link_cells
        self.cells['beam'] = beam_cells
        self.cells['other_line'] = other_line_cells

        self.model_data_finished = True

        if output_file:
            # NOTE(review): cell lists mixing 3- and 4-node entries (e.g. all_faces)
            # are ragged; h5py/numpy may refuse them as one dataset — confirm.
            with h5py.File(output_file, 'w') as f:
                grp1 = f.create_group("ModelInfo")
                for name in self.model_info_names:
                    grp1.create_dataset(name, data=self.model_info[name])
                grp2 = f.create_group("Cell")
                for name in self.cells_names:
                    grp2.create_dataset(name, data=self.cells[name])

    @staticmethod
    def _run_eigen(num_modes):
        """
        Wipe the analysis objects and return the first ``num_modes`` eigenvalues.

        At least two modes are always requested from ``ops.eigen`` (presumably
        because some eigen solvers reject a single-mode request — TODO confirm).
        """
        ops.wipeAnalysis()
        return ops.eigen(max(2, num_modes))[:num_modes]

    def _get_eigenvector(self, mode_tag):
        """
        Return the (num_node, 3) translational mode shape of mode ``mode_tag``.

        Uses ``model_info['NodeTags']`` / ``model_info['model_dims']``, so
        ``get_model_data`` must have been called first.
        """
        vectors = np.zeros((self.model_info['num_node'], 3))
        node_dims = zip(self.model_info['NodeTags'], self.model_info['model_dims'])
        for i, (tag, dim) in enumerate(node_dims):
            vectors[i] = self._to_3d(ops.nodeEigenvector(tag, mode_tag), dim)
        return vectors

    def get_eigen_data(self, modeTag=1, serialize=False, output_file=None):
        """
        Run an eigen analysis and store the frequencies and mode shapes.

        :param modeTag: mode number (scalar)
        :param serialize: if True and modeTag > 1, store every mode from 1 to
                          modeTag as lists; otherwise store only mode ``modeTag``
        :param output_file: optional hdf5 output file; must end with '.hdf5'
        :raises ValueError: if ``output_file`` does not end with '.hdf5'
        :return: None
        """
        if output_file is not None:
            _check_file(output_file, '.hdf5')
        self.get_model_data()
        self.reset_eigen_state()
        if modeTag == 1:
            serialize = False

        # One eigen solve provides every requested mode.
        freqs = np.sqrt(self._run_eigen(modeTag)) / (2 * np.pi)
        if serialize:
            for mode_tag in range(1, modeTag + 1):
                self.eigen['f'].append(freqs[mode_tag - 1])
                self.eigen['eigenvector'].append(self._get_eigenvector(mode_tag))
        else:  # only the single requested mode
            self.eigen['f'] = freqs[-1]
            self.eigen['eigenvector'] = self._get_eigenvector(modeTag)
        self.eigen.update(self.model_info)
        self.eigen.update(self.cells)

        if output_file:
            with h5py.File(output_file, 'w') as f:
                grp = f.create_group("EigenInfo")
                for name in self.eigen_names:
                    grp.create_dataset(name, data=self.eigen[name])

    def get_step_data(self, model_update: bool = False, output_file: str = None, num_steps: int = None):
        """
        Record the nodal responses (and optionally the model) of the current step.

        Call once per analysis step. When ``output_file`` is given, all recorded
        steps are written to it on the ``num_steps``-th call.

        :param model_update: whether to re-extract the model data at every step
        :param output_file: hdf5 file that receives the results of all steps
        :param num_steps: total number of analysis steps; required when
                          ``output_file`` is set, to know when to write the file
        :raises ValueError: if ``output_file`` does not end with '.hdf5', or it
                            is set without ``num_steps``
        """
        if output_file:
            _check_file(output_file, '.hdf5')
            if num_steps is None:
                raise ValueError("When output_file is set, num_steps must be input")

        if model_update or not self.model_data_finished:
            self.get_model_data()

        NodeTags = self.model_info['NodeTags']
        model_dims = self.model_info['model_dims']
        num_node = self.model_info['num_node']
        # Translational responses padded to 3 components; rotations are dropped.
        Node_Disp = np.zeros((num_node, 3))
        Node_Vel = np.zeros((num_node, 3))
        Node_Accel = np.zeros((num_node, 3))
        for i, (tag, dim) in enumerate(zip(NodeTags, model_dims)):
            Node_Disp[i] = self._to_3d(ops.nodeDisp(tag), dim)
            Node_Vel[i] = self._to_3d(ops.nodeVel(tag), dim)
            Node_Accel[i] = self._to_3d(ops.nodeAccel(tag), dim)

        self.node_resp_steps['disp'].append(Node_Disp)
        self.node_resp_steps['vel'].append(Node_Vel)
        self.node_resp_steps['accel'].append(Node_Accel)
        self.model_update = model_update
        if model_update:
            for name in self.model_info_names:
                self.model_info_steps[name].append(self.model_info[name])
            for name in self.cells_names:
                self.cells_steps[name].append(self.cells[name])
        elif self.step_track == 0:
            # Static model: keep a single snapshot instead of per-step lists.
            self.model_info_steps.update(self.model_info)
            self.cells_steps.update(self.cells)

        self.step_track += 1
        if self.step_track == num_steps and output_file:
            with h5py.File(output_file, 'w') as f:
                grp1 = f.create_group("ModelInfoSteps")
                for name in self.model_info_names:
                    grp1.create_dataset(name, data=self.model_info_steps[name])
                grp2 = f.create_group("CellSteps")
                for name in self.cells_names:
                    grp2.create_dataset(name, data=self.cells_steps[name])
                grp3 = f.create_group("NodeRespSteps")
                for name in self.node_resp_names:
                    grp3.create_dataset(name, data=self.node_resp_steps[name])


def sort_xy(x, y):
    """
    Return the index permutation that orders planar points counter-clockwise.

    Points are ordered by polar angle about the centroid of (x, y). If the
    first point in that order does not have the maximum ``x`` value, the
    permutation is rotated by one position.

    :param x: 1-D numpy array of first in-plane coordinates
    :param y: 1-D numpy array of second in-plane coordinates
    :return: integer index array into ``x`` / ``y``
    """
    dx = x - np.mean(x)
    dy = y - np.mean(y)
    cos_theta = dx / np.sqrt(dx ** 2 + dy ** 2)
    # Upper half-plane maps to [0, pi]; the lower half is shifted above it.
    theta = np.where(dy >= 0, np.arccos(cos_theta),
                     4 * np.pi - np.arccos(cos_theta))
    order = np.argsort(theta)
    if x[order][0] != np.max(x):
        order = np.roll(order, 1)
    return order


def counter_clockwise(points, tag, tol=1e-6):
    """
    Order the points of one planar face counter-clockwise and return their tags.

    The sorting plane is chosen from the coordinates:
    - if all x are equal (within ``tol``), the points are sorted in the y-z plane;
    - else if all y are equal, they are sorted in the x-z plane;
    - otherwise they are sorted in the x-y plane.

    The ordering itself is done by ``sort_xy``.

    :param points: sequence of (x, y, z) tuples
    :param tag: sequence of labels aligned with ``points``
    :param tol: absolute tolerance used to detect a constant coordinate
    :return: list of ``tag`` entries in counter-clockwise order
    """
    coords = np.asarray(points, dtype=float)
    x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]
    if np.all(np.abs(x - x[0]) < tol):  # y-z plane
        index = sort_xy(y, z)
    elif np.all(np.abs(y - y[0]) < tol):  # x-z plane
        index = sort_xy(x, z)
    else:  # x-y plane
        index = sort_xy(x, y)
    # Apply the permutation directly. Looking the sorted points back up by
    # value returned the wrong tag for coincident points and was O(n^2).
    return [tag[i] for i in index]
